import argparse
import torch
from torch.utils.data import DataLoader
import numpy as np
import os
from utils import test_accuracy, get_dataset, get_network
from tqdm import tqdm
from torchvision.transforms import v2


def args_parser():
    """Define the command-line interface of the shadow-model scoring script and parse it.

    ``--start`` and ``--end`` pick a sub-range of the shadow models (the script
    iterates ``range(start, end)``); they must be supplied together or not at all.
    A violation is reported through ``parser.error``, which exits the program.

    Returns:
        argparse.Namespace with the parsed options.
    """
    methods = ['cifar10', 'random', 'forgetting', 'DM', 'DSA', 'MTT', 'DATM', 'Diffusion', 'dpsgd', 'dfkd']
    model_types = ['ConvNet', 'ResNet18', 'ResNet18BN']

    parser = argparse.ArgumentParser(description='inference for N shadow models on n data samples')
    add = parser.add_argument
    add('--num_shadow', type=int, required=True, help='the number of shadow models')
    add('--lira_path', type=str, required=True, help='the folder path to save the LiRA results')
    add('--data_path', type=str, required=True, help='load cifar10 dataset')
    add('--method', type=str, choices=methods, help='method of coreset selection or generating synthetic data')
    add('--model_type', type=str, choices=model_types, help='The model type to use')
    # kept as a string flag ('True'/'False'); the caller converts it to a bool
    add('--augmentation', type=str, default='True', choices=['True', 'False'], help='whether to use data augmentation')
    add('--use_dd_aug', action='store_true', help='whether to use transforms in DD')
    add('--avg_case', action='store_true', default=False, help='use average case in-out split')
    add('--epoch', type=int, default=0, help='the epoch of the shadow models')
    add('--start', type=int, default=None, help='start index of the shadow models')
    add('--end', type=int, default=None, help='end index of the shadow models')
    parsed = parser.parse_args()

    # exactly one of the two bounds given -> report whichever one is missing
    no_start = parsed.start is None
    no_end = parsed.end is None
    if no_start != no_end:
        if no_start:
            parser.error("--start should be specified together with --end")
        else:
            parser.error("--end should be specified together with --start")

    return parsed


def augmented_views(img, augmentation, pad=4):
    """Yield the fixed set of test-time augmented views of an image batch.

    The batch is zero-padded by ``pad`` pixels on both spatial sides. It is then
    cropped back to its original height/width at every (shift_y, shift_x) offset
    in (0, -pad, +pad), optionally followed by a horizontal flip. The iteration
    order is fixed (flip outermost, then shift_y, then shift_x), so column k of a
    saved score matrix always corresponds to the same view.

    Args:
        img: image batch; assumed (N, C, H, W) -- only the last two dims are
            cropped/flipped. TODO confirm for every dataset get_dataset returns.
        augmentation: when False, only the un-shifted, un-flipped view is yielded.
        pad: padding in pixels, which is also the maximum shift. The default of 4
            reproduces the original CIFAR-10 setting.

    Yields:
        Tensors with the same shape as ``img``: 18 views when ``augmentation`` is
        True, otherwise 1.
    """
    if augmentation:
        flips, shifts = (False, True), (0, -pad, pad)
    else:
        flips, shifts = (False,), (0,)
    height, width = img.shape[-2], img.shape[-1]
    # Zero (constant) padding of the last two dims on every side -- the same
    # result as torchvision's v2.functional.pad(img, padding=[pad]) on a tensor.
    padded = torch.nn.functional.pad(img, (pad, pad, pad, pad))
    for flip in flips:
        for shift_y in shifts:
            for shift_x in shifts:
                top, left = pad + shift_y, pad + shift_x
                view = padded[..., top:top + height, left:left + width]
                yield torch.flip(view, dims=[-1]) if flip else view


def lira_scores(logits, labels):
    """Return the LiRA score log(p_y) - log(sum_{c != y} p_c) for every sample.

    ``p`` is the softmax of ``logits`` computed in float64. 1e-45 is added inside
    both logarithms so that zero probabilities still give a finite score.

    Args:
        logits: tensor of shape (N, num_classes).
        labels: integer tensor of shape (N,) holding the label of each row.

    Returns:
        float64 numpy array of shape (N,).
    """
    # .double() before .numpy(): also works for half/bfloat16 logits
    logits = logits.detach().cpu().double().numpy()
    labels = torch.as_tensor(labels).cpu().numpy()
    rows = np.arange(labels.shape[0])
    # Subtract the row max before exponentiating so exp never overflows.
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)
    y_true = probs[rows, labels]
    # Sum the remaining classes rather than computing 1 - y_true. Once y_true
    # rounds to 1.0, that subtraction cancels to 0, and every confident sample
    # gets the same saturated score, -log(1e-45) ~= 103.6.
    probs[rows, labels] = 0.0
    y_wrong = probs.sum(axis=1)
    return np.log(y_true + 1e-45) - np.log(y_wrong + 1e-45)


def main():
    """Compute LiRA scores of every training sample for shadow models [start, end).

    For each shadow model ``idx``, the checkpoint
    ``<lira_path>/ckpts_<suffix>/model_epoch_<epoch>_<idx>.pt`` is loaded. A
    (num_train_samples, num_views) score array is then saved to
    ``<lira_path>/scores_<suffix>/score_epoch_<epoch>_<idx>.npy``.
    """
    args = args_parser()
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.augmentation = args.augmentation == 'True'

    suffix = f'{args.model_type}_dd_aug' if args.use_dd_aug else f'{args.model_type}'
    ckpt_path = os.path.join(args.lira_path, f"ckpts_{suffix}")
    score_path = os.path.join(args.lira_path, f"scores_{suffix}")
    # exist_ok avoids the check-then-create race when several --start/--end jobs
    # run at once; missing parent directories are created as well.
    os.makedirs(score_path, exist_ok=True)

    trainset, _ = get_dataset(
        args,
        noisy_targets=not args.avg_case,
        zca=args.method in ['MTT', 'DATM'],
    )
    # shuffle=False keeps row i of every saved score array aligned with sample i.
    train_loader = DataLoader(trainset, batch_size=1024, shuffle=False, num_workers=4)

    start_id = 0 if args.start is None else args.start
    end_id = args.num_shadow if args.end is None else args.end
    for idx in tqdm(range(start_id, end_id)):
        print({"processing model": idx})
        model_file = os.path.join(ckpt_path, f"model_epoch_{args.epoch}_{idx}.pt")
        # The checkpoint is a whole pickled nn.Module (not a state_dict).
        # weights_only=False is therefore required on torch >= 2.6, where the
        # default became True. Unpickling can execute code, so only load
        # checkpoints produced by trusted training runs.
        # map_location lets GPU-saved checkpoints load on CPU-only hosts, and
        # .to() puts the model on the same device as the input batches.
        model = torch.load(model_file, map_location=args.device, weights_only=False)
        model = model.to(args.device)
        model.eval()

        batch_scores = []
        with torch.no_grad():
            for img, labels in train_loader:
                img = img.to(args.device)
                views = augmented_views(img, args.augmentation)
                # one column per augmented view
                batch_scores.append(np.stack([lira_scores(model(v), labels) for v in views], axis=1))

        # (num_samples, num_aug)
        np.save(os.path.join(score_path, f'score_epoch_{args.epoch}_{idx}.npy'), np.concatenate(batch_scores, axis=0))


if __name__ == '__main__':
    main()
    